#  Copyright 2022-present, the Waterdip Labs Pvt. Ltd.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
import datetime
import operator
from typing import Dict, List, Union

from loguru import logger
from pyparsing import Forward, Group, Suppress, Word, alphas, delimitedList, nums

from dcs_core.core.common.models.metric import MetricsType, MetricValue
from dcs_core.core.metric.base import Metric, MetricIdentity


class CombinedMetric(Metric):
    """
    CombinedMetric computes a value by combining other metric values with an
    arithmetic expression.

    ``self.expression`` uses a nested call syntax such as
    ``div(sum(metric_a, metric_b), 2)``. Each leaf is either a run of digits
    (a numeric literal) or the name of another metric, resolved against the
    ``metric_name`` tag of the supplied ``MetricValue`` objects. Supported
    operations are ``sum``, ``sub``, ``mul`` and ``div``; each takes exactly
    two arguments.
    """

    # Binary operations allowed in a combined-metric expression.
    _OPERATIONS = {
        "sum": operator.add,
        "sub": operator.sub,
        "mul": operator.mul,
        "div": operator.truediv,
    }

    def _find_metric_values(
        self, metric_arg: Union[str, int, float], metric_values: List[MetricValue]
    ) -> float:
        """
        Resolve a single expression leaf to a number.

        :param metric_arg: a token from the expression. A number, or a string
            of digits, is treated as a literal; any other string is treated as
            the name of another metric.
        :param metric_values: metric values to search. A value matches when its
            ``metric_name`` tag equals ``metric_arg``.
        :return: the literal as a float, or the matching metric's ``value``.
        :raises ValueError: if ``metric_arg`` names no metric in ``metric_values``.
        """
        # Check the type first so non-string arguments never hit ``.isnumeric()``.
        if isinstance(metric_arg, (int, float)):
            return float(metric_arg)
        if metric_arg.isnumeric():
            return float(metric_arg)
        for metric_value in metric_values:
            # .get() skips values without a "metric_name" tag instead of
            # aborting the whole search with a KeyError.
            if metric_value.tags.get("metric_name") == metric_arg:
                return metric_value.value
        raise ValueError(f"Metric {metric_arg} not found in a {self.expression}")

    def _metric_expression_parser(
        self, expression_str: str, metric_values: List[MetricValue]
    ) -> Union[Dict, float]:
        """
        Parse ``expression_str`` into a tree of operation nodes.

        Grammar: ``expr := operation "(" expr ("," expr)* ")" | argument``.
        ``operation`` is alphabetic. ``argument`` is letters, digits and
        underscores.

        Leaves are resolved to numbers during parsing via
        :meth:`_find_metric_values`. Each operation becomes a dict
        ``{"operation": <name>, "args": [<child>, ...]}``.

        :return: the root operation dict, or a bare number when the whole
            expression is a single leaf.
        :raises ValueError: if the expression does not parse or references an
            unknown metric. The underlying error is chained as ``__cause__``.
        """
        try:
            fwd_expr = Forward()
            operation = Word(alphas)
            arguments = Word("_" + nums + alphas)
            lp, rp = map(Suppress, "()")
            reg_expression = operation + lp + delimitedList(fwd_expr) + rp
            fwd_expr << (reg_expression | arguments)
            # A single token is a leaf. Anything longer is
            # [operation, arg1, arg2, ...] from ``reg_expression``.
            fwd_expr.setParseAction(
                lambda tokens: self._find_metric_values(tokens[0], metric_values)
                if len(tokens) == 1
                else {"operation": tokens[0], "args": list(tokens[1:])}
            )
            return fwd_expr.parseString(expression_str, parseAll=True)[0]
        except Exception as e:
            raise ValueError(f"Invalid expression {expression_str}") from e

    def _perform_operation(self, operation_data) -> float:
        """
        Evaluate an operation node produced by :meth:`_metric_expression_parser`.

        Nested operation dicts among the arguments are evaluated recursively.
        ``operation_data`` is not modified.

        :param operation_data: ``{"operation": <name>, "args": [lhs, rhs]}``,
            where each arg is a number or another operation dict.
        :return: the result of applying the operation to the two arguments.
        :raises ValueError: if the node does not have exactly two arguments or
            names an unsupported operation.
        :raises ZeroDivisionError: for ``div`` with a zero right-hand side.
        """
        operation = operation_data["operation"]
        args = operation_data["args"]
        # Reject too few as well as too many arguments. Fewer than two used to
        # surface as an IndexError.
        if len(args) != 2:
            raise ValueError("Operation must have only two arguments")

        func = self._OPERATIONS.get(operation)
        if func is None:
            raise ValueError("Invalid operation")

        # Evaluate into locals instead of writing back into ``args``, so the
        # parsed tree stays untouched.
        lhs, rhs = (
            self._perform_operation(arg) if isinstance(arg, dict) else arg
            for arg in args
        )
        return func(lhs, rhs)

    def get_metric_identity(self):
        """
        Build this metric's identity from its type (``MetricsType.COMBINED``),
        its name and its expression.
        """
        return MetricIdentity.generate_identity(
            metric_type=MetricsType.COMBINED,
            metric_name=self.name,
            expression=self.expression,
        )

    def _generate_metric_value(self, metric_values: List[MetricValue]):
        """
        Evaluate ``self.expression`` against ``metric_values``.

        :param metric_values: previously computed metric values that the
            expression may reference by their ``metric_name`` tag.
        :return: the expression result rounded to 2 decimal places.
        :raises ValueError: if the expression is invalid or references a metric
            not present in ``metric_values``.
        """
        expression_data = self._metric_expression_parser(self.expression, metric_values)
        if isinstance(expression_data, dict):
            result = self._perform_operation(expression_data)
        else:
            # The expression was a single leaf (metric name or literal), which
            # the parser has already resolved to a number.
            result = expression_data
        return round(result, 2)
